{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "c8b321fd",
   "metadata": {},
   "outputs": [],
   "source": [
    "import ast\n",
    "from collections import deque\n",
    "\n",
    "import networkx as nx\n",
    "import pandas as pd\n",
    "import requests\n",
    "from tqdm import tqdm\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "5b7c4888",
   "metadata": {},
   "outputs": [],
   "source": [
    "config_data = {\n",
    "    'BASE_URI' : 'https://spoke.rbvi.ucsf.edu',\n",
    "    'cutoff_Compound_max_phase' : 3,\n",
    "    'cutoff_Protein_source' : ['SwissProt'],\n",
    "    'cutoff_DaG_diseases_sources' : ['knowledge', 'experiments'],\n",
    "    'cutoff_DaG_textmining' : 3,\n",
    "    'cutoff_CtD_phase' : 3,\n",
    "    'cutoff_PiP_confidence' : 0.7,\n",
    "    'cutoff_ACTeG_level' : ['Low', 'Medium', 'High'],\n",
    "    'cutoff_DpL_average_prevalence' : 0.001,\n",
    "    'depth' : 2\n",
    "}\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "145fa12e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_spoke_api_resp(base_uri, end_point, params=None):\n",
    "    \"\"\"Issue a GET request against a SPOKE REST endpoint.\n",
    "\n",
    "    Args:\n",
    "        base_uri: Scheme and host of the service; `end_point` is appended to it as-is.\n",
    "        end_point: Path of the endpoint, e.g. \"/api/v1/types\".\n",
    "        params: Optional query parameters; a falsy value sends no query string.\n",
    "\n",
    "    Returns:\n",
    "        The response object returned by `requests.get`.\n",
    "    \"\"\"\n",
    "    # A falsy `params` is normalised to None so the call matches a bare requests.get(url).\n",
    "    query = params or None\n",
    "    return requests.get(base_uri + end_point, params=query)\n",
    "\n",
    "    \n",
    "def _edge_provenance(properties):\n",
    "    \"\"\"Build the provenance text for one SPOKE edge.\n",
    "\n",
    "    The `sources` property is tried first, then `source`, then the stringified\n",
    "    `preprint_list` / `pmid_list` properties. \"SPOKE-KG\" is returned when none\n",
    "    of them can be used.\n",
    "    \"\"\"\n",
    "    for key in (\"sources\", \"source\"):\n",
    "        value = properties.get(key)\n",
    "        if value is None:\n",
    "            continue\n",
    "        if isinstance(value, str):\n",
    "            # Joining a bare string would put \", \" between its characters.\n",
    "            return value\n",
    "        try:\n",
    "            return \", \".join(str(entry) for entry in value)\n",
    "        except TypeError:\n",
    "            continue  # not iterable; fall through to the next property\n",
    "    try:\n",
    "        preprints = ast.literal_eval(properties[\"preprint_list\"])\n",
    "        if len(preprints) > 0:\n",
    "            return \", \".join(str(entry) for entry in preprints)\n",
    "        # Built as a list (not a lazy map object) so that emptiness can be tested.\n",
    "        pmids = [\"pubmedId:\" + str(entry) for entry in ast.literal_eval(properties[\"pmid_list\"])]\n",
    "        if pmids:\n",
    "            return \", \".join(pmids)\n",
    "        return \"Based on data from Institute For Systems Biology (ISB)\"\n",
    "    except (KeyError, TypeError, ValueError, SyntaxError):\n",
    "        # Missing or unparsable literature lists: fall back to the generic label.\n",
    "        return \"SPOKE-KG\"\n",
    "\n",
    "\n",
    "def get_context_using_spoke_api(node_value):\n",
    "    \"\"\"Fetch the SPOKE neighbourhood of a Disease node and verbalise it.\n",
    "\n",
    "    Args:\n",
    "        node_value: Value of the Disease `name` attribute to look up.\n",
    "\n",
    "    Returns:\n",
    "        A `(context, edge_table)` tuple. `context` is a single string holding one\n",
    "        sentence per edge followed by a closing sentence built from the first\n",
    "        item of the response. `edge_table` is a DataFrame with the columns\n",
    "        `source`, `edge_type`, `target`, `provenance`, `evidence`, `predicate`\n",
    "        and `context`.\n",
    "\n",
    "    Raises:\n",
    "        requests.HTTPError: If either SPOKE request returns an error status.\n",
    "    \"\"\"\n",
    "    base_uri = config_data[\"BASE_URI\"]\n",
    "    types_resp = get_spoke_api_resp(base_uri, \"/api/v1/types\")\n",
    "    types_resp.raise_for_status()\n",
    "    spoke_types = types_resp.json()\n",
    "\n",
    "    excluded_node_types = {\"DatabaseTimestamp\", \"Version\"}\n",
    "    api_params = {\n",
    "        \"node_filters\": [name for name in spoke_types[\"nodes\"] if name not in excluded_node_types],\n",
    "        \"edge_filters\": list(spoke_types[\"edges\"]),\n",
    "    }\n",
    "    # The remaining settings are forwarded under the same names they have in config_data.\n",
    "    forwarded_keys = (\n",
    "        \"cutoff_Compound_max_phase\",\n",
    "        \"cutoff_Protein_source\",\n",
    "        \"cutoff_DaG_diseases_sources\",\n",
    "        \"cutoff_DaG_textmining\",\n",
    "        \"cutoff_CtD_phase\",\n",
    "        \"cutoff_PiP_confidence\",\n",
    "        \"cutoff_ACTeG_level\",\n",
    "        \"cutoff_DpL_average_prevalence\",\n",
    "        \"depth\",\n",
    "    )\n",
    "    for key in forwarded_keys:\n",
    "        api_params[key] = config_data[key]\n",
    "\n",
    "    nbr_end_point = \"/api/v1/neighborhood/{}/{}/{}\".format(\"Disease\", \"name\", node_value)\n",
    "    nbr_resp = get_spoke_api_resp(base_uri, nbr_end_point, params=api_params)\n",
    "    nbr_resp.raise_for_status()\n",
    "    node_context = nbr_resp.json()\n",
    "\n",
    "    nbr_nodes = []\n",
    "    nbr_edges = []\n",
    "    for item in node_context:\n",
    "        data = item[\"data\"]\n",
    "        neo4j_type = data[\"neo4j_type\"]\n",
    "        if \"_\" not in neo4j_type:\n",
    "            # No underscore in the type: the item is treated as a node.\n",
    "            properties = data[\"properties\"]\n",
    "            # Proteins are labelled by description, other nodes by name; identifier is the fallback.\n",
    "            label_key = \"description\" if neo4j_type == \"Protein\" else \"name\"\n",
    "            if label_key not in properties:\n",
    "                label_key = \"identifier\"\n",
    "            nbr_nodes.append((neo4j_type, data[\"id\"], properties[label_key]))\n",
    "        else:\n",
    "            # Underscore in the type: the item is treated as an edge.\n",
    "            evidence = data.get(\"properties\")\n",
    "            provenance = _edge_provenance(evidence) if isinstance(evidence, dict) else \"SPOKE-KG\"\n",
    "            nbr_edges.append((data[\"source\"], neo4j_type, data[\"target\"], provenance, evidence))\n",
    "\n",
    "    nodes_df = pd.DataFrame(nbr_nodes, columns=[\"node_type\", \"node_id\", \"node_name\"])\n",
    "    edges_df = pd.DataFrame(nbr_edges, columns=[\"source\", \"edge_type\", \"target\", \"provenance\", \"evidence\"])\n",
    "\n",
    "    # \"<type> <name>\" label per node id; it replaces the raw ids on both ends of every edge.\n",
    "    node_labels = pd.DataFrame({\n",
    "        \"node_id\": nodes_df[\"node_id\"],\n",
    "        \"label\": nodes_df[\"node_type\"] + \" \" + nodes_df[\"node_name\"],\n",
    "    })\n",
    "    edge_table = edges_df\n",
    "    for endpoint in (\"source\", \"target\"):\n",
    "        # Inner merge: edges whose endpoint is not among the returned nodes are dropped.\n",
    "        edge_table = (\n",
    "            edge_table\n",
    "            .merge(node_labels, left_on=endpoint, right_on=\"node_id\")\n",
    "            .drop(columns=[endpoint, \"node_id\"])\n",
    "            .rename(columns={\"label\": endpoint})\n",
    "        )\n",
    "    edge_table = edge_table[[\"source\", \"edge_type\", \"target\", \"provenance\", \"evidence\"]]\n",
    "    # assign() returns new frames, avoiding in-place writes on a column-sliced frame.\n",
    "    edge_table = edge_table.assign(predicate=lambda df: df[\"edge_type\"].str.split(\"_\").str[0])\n",
    "    edge_table = edge_table.assign(\n",
    "        context=lambda df: (\n",
    "            df[\"source\"] + \" \" + df[\"predicate\"].str.lower() + \" \" + df[\"target\"]\n",
    "            + \" and Provenance of this association is \" + df[\"provenance\"] + \".\"\n",
    "        )\n",
    "    )\n",
    "\n",
    "    sentences = edge_table[\"context\"].dropna().tolist()\n",
    "    # assumes the first item of the response is the queried Disease node - TODO confirm\n",
    "    seed_properties = node_context[0][\"data\"][\"properties\"]\n",
    "    sentences.append(\n",
    "        node_value + \" has a \" + seed_properties[\"source\"] + \" identifier of \" + seed_properties[\"identifier\"]\n",
    "        + \" and Provenance of this is from \" + seed_properties[\"source\"] + \".\"\n",
    "    )\n",
    "    # Joining with a space keeps the closing sentence separated from the last edge sentence.\n",
    "    context = \" \".join(sentences)\n",
    "    return context, edge_table\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "2cc8b90c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 645 ms, sys: 125 ms, total: 770 ms\n",
      "Wall time: 11.1 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "node_name = 'amyloidosis'\n",
    "node_context,context_table = get_context_using_spoke_api(node_name)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "a14fc76f",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "51093it [00:01, 46067.40it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 1.1 s, sys: 56.4 ms, total: 1.16 s\n",
      "Wall time: 1.16 s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "# nx.Graph keeps a single edge per node pair, so a later row for the same\n",
    "# pair overwrites the stored `edge_type`.\n",
    "graph = nx.Graph()\n",
    "\n",
    "# itertuples avoids building a Series per row; `total` lets tqdm show real progress.\n",
    "for row in tqdm(context_table.itertuples(index=False), total=len(context_table)):\n",
    "    graph.add_edge(row.source, row.target, edge_type=row.predicate)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "0765f43b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'edge_type': 'ASSOCIATES'}"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Spot-check one edge: show the attribute dict stored between the disease node and 'Gene APOE'.\n",
    "disease_node = 'Disease amyloidosis'\n",
    "graph[disease_node]['Gene APOE']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "98535843",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 16.7 ms, sys: 1.02 ms, total: 17.8 ms\n",
      "Wall time: 17.6 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "# Two-hop paths: disease -> Gene/Protein/Disease neighbour -> Compound,\n",
    "# skipping compounds whose edge to the intermediate node is CONTRAINDICATES.\n",
    "intermediate_prefixes = ('Gene', 'Protein', 'Disease')\n",
    "extracted_path = []\n",
    "for neighbor_1 in graph.neighbors(disease_node):\n",
    "    if not neighbor_1.startswith(intermediate_prefixes):\n",
    "        continue\n",
    "    first_hop = graph[disease_node][neighbor_1]['edge_type']\n",
    "    for neighbor_2 in graph.neighbors(neighbor_1):\n",
    "        if not neighbor_2.startswith('Compound'):\n",
    "            continue\n",
    "        second_hop = graph[neighbor_1][neighbor_2]['edge_type']\n",
    "        if second_hop != 'CONTRAINDICATES':\n",
    "            extracted_path.append((disease_node, first_hop, neighbor_1, second_hop, neighbor_2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "5a046cd0",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[('Disease amyloidosis',\n",
       "  'RESEMBLES',\n",
       "  'Disease cardiomyopathy',\n",
       "  'TREATS',\n",
       "  'Compound Dexrazoxane'),\n",
       " ('Disease amyloidosis',\n",
       "  'RESEMBLES',\n",
       "  'Disease cardiomyopathy',\n",
       "  'TREATS',\n",
       "  'Compound Prednisone'),\n",
       " ('Disease amyloidosis',\n",
       "  'ASSOCIATES',\n",
       "  'Gene APOE',\n",
       "  'DOWNREGULATES',\n",
       "  'Compound Alizapride'),\n",
       " ('Disease amyloidosis',\n",
       "  'ASSOCIATES',\n",
       "  'Gene APOE',\n",
       "  'DOWNREGULATES',\n",
       "  'Compound Proglumide'),\n",
       " ('Disease amyloidosis',\n",
       "  'ASSOCIATES',\n",
       "  'Gene APOE',\n",
       "  'DOWNREGULATES',\n",
       "  'Compound Idelalisib'),\n",
       " ('Disease amyloidosis',\n",
       "  'ASSOCIATES',\n",
       "  'Gene APOE',\n",
       "  'UPREGULATES',\n",
       "  'Compound Lorazepam')]"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Preview a small slice of the extracted two-hop paths.\n",
    "extracted_path[10:16]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "id": "2dfd63be",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Current node: Disease amyloidosis, Path: ['Disease amyloidosis']\n"
     ]
    }
   ],
   "source": [
    "# Print every path returned by find_connected_compounds as an arrow-separated chain.\n",
    "# NOTE(review): `find_connected_compounds` is not defined in any earlier cell of this\n",
    "# notebook; it must already exist in the kernel when this cell runs.\n",
    "\n",
    "for path in find_connected_compounds(graph, \"Disease amyloidosis\"):\n",
    "    print(\" -> \".join(path))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "id": "a265df27",
   "metadata": {},
   "outputs": [
    {
     "ename": "SyntaxError",
     "evalue": "'yield' outside function (3157857207.py, line 13)",
     "output_type": "error",
     "traceback": [
      "\u001b[0;36m  Cell \u001b[0;32mIn[39], line 13\u001b[0;36m\u001b[0m\n\u001b[0;31m    yield path + [neighbor]\u001b[0m\n\u001b[0m    ^\u001b[0m\n\u001b[0;31mSyntaxError\u001b[0m\u001b[0;31m:\u001b[0m 'yield' outside function\n"
     ]
    }
   ],
   "source": [
    "disease_node = 'Disease amyloidosis'\n",
    "queue = deque([(disease_node, [disease_node])])\n",
    "\n",
    "while queue:\n",
    "    current_node, path = queue.popleft()\n",
    "\n",
    "    # Check if the current node is a Gene node\n",
    "    if current_node.startswith(\"Gene\"):\n",
    "        # If so, check its neighbors for Compound nodes\n",
    "        for neighbor in graph.neighbors(current_node):\n",
    "            if neighbor.startswith(\"Compound\"):\n",
    "                # If a Compound node is found, yield the path\n",
    "                yield path + [neighbor]\n",
    "            elif neighbor not in path:\n",
    "                # If a non-Compound node is found, add it to the queue\n",
    "                queue.append((neighbor, path + [neighbor]))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 92,
   "id": "4d4e617f",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.embeddings.sentence_transformer import SentenceTransformerEmbeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "782eb8a5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): the original assignment was left unfinished. The model name below follows the\n",
    "# LangChain documentation example; replace it if a different model was intended.\n",
    "embedding_function = SentenceTransformerEmbeddings(model_name=\"all-MiniLM-L6-v2\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ee365011",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Split the context string on \". \" and embed each resulting piece.\n",
    "node_context_list = node_context.split(\". \")        \n",
    "node_context_embeddings = embedding_function.embed_documents(node_context_list)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
